import torch
import pytorch3d
import torch.nn.functional as F

from pytorch3d.ops import interpolate_face_attributes

from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    AmbientLights,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    HardPhongShader,
    TexturesVertex,
    TexturesUV,
    Materials,
)
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
from pytorch3d.renderer.mesh.shader import ShaderBase


def get_cos_angle(points, normals, camera_position):
    """
    Compute the clamped cosine between surface normals and the surface->camera
    direction.

    Args:
        points: tensor of surface points; the first dim is treated as the batch
            dim and the last dim must be 3 (e.g. (N, H, W, K, 3)).
        normals: surface normals with exactly the same shape as ``points``.
            They are normalized here, so they need not be unit length.
        camera_position: camera centre(s). Its batch dim is broadcast to match
            ``points`` with pytorch3d's ``convert_to_tensors_and_broadcast``.

    Returns:
        Tensor with the shape of ``points`` except for a trailing dim of 1.
        It holds cos(angle) clamped to [0, 1], so surfaces facing away from
        the camera give 0.

    Raises:
        ValueError: if ``points`` and ``normals`` differ in shape.
    """
    if normals.shape != points.shape:
        msg = "Expected points and normals to have the same shape: got %r, %r"
        raise ValueError(msg % (points.shape, normals.shape))

    # Bring the camera position to the same batch size as points. The
    # broadcast copy of points itself is not needed.
    camera_position = convert_to_tensors_and_broadcast(
        points, camera_position, device=points.device
    )[1]

    # Unless the camera position is already given per point, reshape it to
    # (B, 1, ..., 1, 3) so it broadcasts over every intermediate dim of points.
    if camera_position.shape != normals.shape:
        target_shape = (-1,) + (1,) * len(points.shape[1:-1]) + (3,)
        camera_position = camera_position.view(target_shape)

    unit_normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
    to_camera = F.normalize(camera_position - points, p=2, dim=-1, eps=1e-6)

    # Dot product of two unit vectors == cosine of the angle between them.
    cosine = (to_camera * unit_normals).sum(dim=-1, keepdim=True)
    return cosine.clamp(0, 1)


def _geometry_shading_with_pixels(
    meshes, fragments, lights, cameras, materials, texels
):
    """
    Interpolate per-pixel surface position and normal, and derive depth and
    the view cosine.

    Args:
        meshes: Batch of meshes. Its packed vertices, faces and vertex normals
            are used.
        fragments: Fragments named tuple from rasterization. ``pix_to_face``,
            ``bary_coords`` and ``zbuf`` are read.
        lights: Not used by this function; accepted for interface
            compatibility.
        cameras: Cameras batch. Only ``get_camera_center()`` is called.
        materials: Not used by this function; accepted for interface
            compatibility.
        texels: Not used by this function; accepted for interface
            compatibility.

    Returns:
        Tuple ``(pixel_coords, pixel_normals, depths, cos_angles)``:
            pixel_coords: barycentric interpolation of
                ``meshes.verts_packed()``. It is in the same coordinate frame
                as the mesh vertices; no camera transform is applied here.
            pixel_normals: barycentric interpolation of
                ``meshes.verts_normals_packed()``. It is not re-normalized.
            depths: ``fragments.zbuf`` with a trailing singleton dim added.
            cos_angles: ``get_cos_angle`` evaluated on the two tensors above
                and ``cameras.get_camera_center()``.
    """
    packed_verts = meshes.verts_packed()  # (V, 3)
    packed_faces = meshes.faces_packed()  # (F, 3)
    packed_vert_normals = meshes.verts_normals_packed()  # (V, 3)

    def _interpolate(face_attrs):
        # Blend a per-face-vertex attribute using the rasterizer's barycentrics.
        return interpolate_face_attributes(
            fragments.pix_to_face, fragments.bary_coords, face_attrs
        )

    pixel_coords = _interpolate(packed_verts[packed_faces])
    pixel_normals = _interpolate(packed_vert_normals[packed_faces])

    camera_center = cameras.get_camera_center()
    cos_angles = get_cos_angle(pixel_coords, pixel_normals, camera_center)

    depths = fragments.zbuf.unsqueeze(-1)
    return pixel_coords, pixel_normals, depths, cos_angles


class HardGeometryShader(ShaderBase):
    """
    Shader that outputs per-pixel geometric information rather than a shaded
    image.

    ``forward`` returns ``(verts, normals, depths, cos_angles, texels,
    fragments)``:
    - The first five entries are per-pixel buffers, each passed through
      pytorch3d's ``hard_rgb_blend`` with the resolved ``blend_params``.
    - ``fragments`` is the rasterizer output, returned unchanged.
    """

    def forward(self, fragments, meshes, **kwargs):
        """
        Produce blended per-pixel geometry buffers for ``meshes``.

        Args:
            fragments: rasterizer output for ``meshes``.
            meshes: batch of meshes. ``meshes.textures`` must provide
                ``maps_padded``, ``faces_uvs_padded`` and ``verts_uvs_padded``
                (e.g. ``TexturesUV``), because ``texel_from_uv`` is called.
            **kwargs: optional ``lights``, ``materials`` and ``blend_params``
                that override the shader's own attributes. Cameras are
                resolved by ``ShaderBase._get_cameras``.

        Returns:
            Tuple ``(verts, normals, depths, cos_angles, texels, fragments)``.
        """
        cameras = super()._get_cameras(**kwargs)
        # NOTE(review): these UV-coordinate texels are only forwarded to the
        # geometry helper. The ``texels`` returned below are the mesh's own
        # sampled texture, which matches the previous behavior. Confirm
        # whether the UV map was meant to be returned instead.
        uv_texels = self.texel_from_uv(fragments, meshes)

        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        verts, normals, depths, cos_angles = _geometry_shading_with_pixels(
            meshes=meshes,
            fragments=fragments,
            texels=uv_texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        texels = meshes.sample_textures(fragments)

        # Blend each buffer the same way, in the same order as before.
        verts, normals, depths, cos_angles, texels = (
            hard_rgb_blend(buffer, fragments, blend_params)
            for buffer in (verts, normals, depths, cos_angles, texels)
        )
        return verts, normals, depths, cos_angles, texels, fragments

    def texel_from_uv(self, fragments, meshes):
        """
        Sample a per-pixel UV map for ``meshes`` using the mesh's own UVs.

        A 2x2, 2-channel texture with 0/1 corner values is built and sampled
        bilinearly through the mesh's ``faces_uvs_padded`` and
        ``verts_uvs_padded``. Each texel therefore interpolates the mesh UV
        coordinates. The channel order appears to be (v, u) given pytorch3d's
        v-flip convention; confirm before relying on it.

        ``meshes.textures`` is replaced only for the duration of sampling. It
        is restored afterwards, even if sampling raises.

        Returns:
            The sampled texels with one zero channel appended (2 + 1 = 3
            channels). The padding presumably lets the UV map be blended like
            an RGB buffer.
        """
        original_textures = meshes.textures
        maps = original_textures.maps_padded()
        reference = maps[0]
        corner_uv = torch.tensor(
            [[[1, 0], [1, 1]], [[0, 0], [0, 1]]],
            dtype=reference.dtype,
            device=reference.device,
        )
        uv_textures = TexturesUV(
            # One independent copy of the 2x2 map per batch element.
            [corner_uv.clone() for _ in maps],
            original_textures.faces_uvs_padded(),
            original_textures.verts_uvs_padded(),
            sampling_mode="bilinear",
        )

        meshes.textures = uv_textures
        try:
            texels = meshes.sample_textures(fragments)
        finally:
            # Never leave the caller's mesh holding the synthetic UV texture.
            meshes.textures = original_textures

        return torch.cat((texels, torch.zeros_like(texels[..., :1])), dim=-1)
